In this notebook we use the segmentation models library built in PyTorch to build image segmentation models for the Magnetic tile defect dataset. This dataset is interesting because it is highly imbalanced, with less than 1% of pixels belonging to the target class. Using the segmentation models library we can try several different loss functions, including binary cross-entropy, focal loss, and Tversky loss, and compare their performance.

Additionally, following an example in the segmentation models library, we'll use PyTorch Lightning to further simplify the training process in PyTorch, and MLflow to log hyperparameters and metrics.

This code was built with the help of Detection of Surface Defects in Magnetic Tile Images by Dr. Mitra P. Danesh.

#!pip install torch torchvision 
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
!pip install pytorch-lightning
!pip install mlflow
Collecting torch
  Downloading torch-1.12.0-cp39-cp39-manylinux1_x86_64.whl (776.3 MB)
     |████████████████████████████████| 776.3 MB 2.1 kB/s            
Collecting torchvision
  Downloading torchvision-0.13.0-cp39-cp39-manylinux1_x86_64.whl (19.1 MB)
     |████████████████████████████████| 19.1 MB 865 kB/s            
Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.9/site-packages (from torch) (3.7.4.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.9/site-packages (from torchvision) (1.19.5)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.9/site-packages (from torchvision) (8.4.0)
Requirement already satisfied: requests in /opt/conda/lib/python3.9/site-packages (from torchvision) (2.27.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision) (2.0.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision) (2021.10.8)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision) (3.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision) (1.26.8)
Installing collected packages: torch, torchvision
Successfully installed torch-1.12.0 torchvision-0.13.0
Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-ba97ov95
  Running command git clone --filter=blob:none -q https://github.com/qubvel/segmentation_models.pytorch /tmp/pip-req-build-ba97ov95
  Resolved https://github.com/qubvel/segmentation_models.pytorch to commit 740dab561ccf54a9ae4bb5bda3b8b18df3790025
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: tqdm in /opt/conda/lib/python3.9/site-packages (from segmentation-models-pytorch==0.3.0.dev0) (4.62.3)
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
     |████████████████████████████████| 58 kB 3.3 MB/s            
  Preparing metadata (setup.py) ... done
Requirement already satisfied: pillow in /opt/conda/lib/python3.9/site-packages (from segmentation-models-pytorch==0.3.0.dev0) (8.4.0)
Collecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torchvision>=0.5.0 in /opt/conda/lib/python3.9/site-packages (from segmentation-models-pytorch==0.3.0.dev0) (0.13.0)
Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
     |████████████████████████████████| 376 kB 6.6 MB/s            
Requirement already satisfied: torch in /opt/conda/lib/python3.9/site-packages (from efficientnet-pytorch==0.6.3->segmentation-models-pytorch==0.3.0.dev0) (1.12.0)
Collecting munch
  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)
Requirement already satisfied: numpy in /opt/conda/lib/python3.9/site-packages (from torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (1.19.5)
Requirement already satisfied: requests in /opt/conda/lib/python3.9/site-packages (from torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (2.27.1)
Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.9/site-packages (from torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (3.7.4.3)
Requirement already satisfied: six in /opt/conda/lib/python3.9/site-packages (from munch->pretrainedmodels==0.7.4->segmentation-models-pytorch==0.3.0.dev0) (1.15.0)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (2.0.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (2021.10.8)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (1.26.8)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.9/site-packages (from requests->torchvision>=0.5.0->segmentation-models-pytorch==0.3.0.dev0) (3.1)
Building wheels for collected packages: segmentation-models-pytorch, efficientnet-pytorch, pretrainedmodels
  Building wheel for segmentation-models-pytorch (pyproject.toml) ... done
  Created wheel for segmentation-models-pytorch: filename=segmentation_models_pytorch-0.3.0.dev0-py3-none-any.whl size=97989 sha256=6a83fe356cdef1ec3e8122f481cc929e2c2b8a4fb7d34a29f2cc946ff9642b59
  Stored in directory: /tmp/pip-ephem-wheel-cache-xw45cnat/wheels/18/f6/12/bdbad33e766c5fddfa996bfb2545d31ca070438d37e5b76408
  Building wheel for efficientnet-pytorch (setup.py) ... done
  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-py3-none-any.whl size=12421 sha256=2f18917afcad49b5cba9fbfd933b2b521984185668657c1f2a7ac7e3376fb137
  Stored in directory: /home/jovyan/.cache/pip/wheels/70/f8/49/20f330df3f946fed839df657dd2156c929d6d7b5f774d9650e
  Building wheel for pretrainedmodels (setup.py) ... done
  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60965 sha256=c5189a41cbaee763f23109c04e67d67b0848d41b32b80f4a64cd85bcb04e696a
  Stored in directory: /home/jovyan/.cache/pip/wheels/d1/3b/4e/2f3015f1ab76f34be28e04c4bcee27e8cabfa70d2eadf8bc3b
Successfully built segmentation-models-pytorch efficientnet-pytorch pretrainedmodels
Installing collected packages: munch, timm, pretrainedmodels, efficientnet-pytorch, segmentation-models-pytorch
Successfully installed efficientnet-pytorch-0.6.3 munch-2.5.0 pretrainedmodels-0.7.4 segmentation-models-pytorch-0.3.0.dev0 timm-0.4.12
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
     |████████████████████████████████| 585 kB 3.8 MB/s            
Requirement already satisfied: packaging>=17.0 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (21.3)
Requirement already satisfied: tqdm>=4.57.0 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (4.62.3)
Requirement already satisfied: torch>=1.8.* in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (1.12.0)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.9.2-py3-none-any.whl (419 kB)
     |████████████████████████████████| 419 kB 6.2 MB/s            
Requirement already satisfied: tensorboard>=2.2.0 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (2.6.0)
Requirement already satisfied: numpy>=1.17.2 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (1.19.5)
Collecting typing-extensions>=4.0.0
  Downloading typing_extensions-4.2.0-py3-none-any.whl (24 kB)
Requirement already satisfied: protobuf<=3.20.1 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (3.18.1)
Requirement already satisfied: PyYAML>=5.4 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (6.0)
Collecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /opt/conda/lib/python3.9/site-packages (from pytorch-lightning) (2022.1.0)
Requirement already satisfied: requests in /opt/conda/lib/python3.9/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.27.1)
Requirement already satisfied: aiohttp in /opt/conda/lib/python3.9/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.8.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.9/site-packages (from packaging>=17.0->pytorch-lightning) (3.0.6)
Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (60.5.0)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.8.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.6)
Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.41.1)
Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.35.0)
Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.15.0)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.6.0)
Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (2.0.1)
Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.37.1)
Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.9/site-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.6)
Requirement already satisfied: six in /opt/conda/lib/python3.9/site-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch-lightning) (1.15.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.9/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.9/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.7)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.9/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.0)
Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (4.10.0)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.9/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.0.10)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.9/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.9/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.26.8)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.9/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.10.8)
Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (5.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.2.0)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (4.0.2)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.4.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.7.2)
Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.9/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.2.0)
Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.1.1)
Installing collected packages: typing-extensions, torchmetrics, pyDeprecate, pytorch-lightning
  Attempting uninstall: typing-extensions
    Found existing installation: typing-extensions 3.7.4.3
    Uninstalling typing-extensions-3.7.4.3:
      Successfully uninstalled typing-extensions-3.7.4.3
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.6.2 requires typing-extensions~=3.7.4, but you have typing-extensions 4.2.0 which is incompatible.
Successfully installed pyDeprecate-0.3.2 pytorch-lightning-1.6.4 torchmetrics-0.9.2 typing-extensions-4.2.0
Collecting mlflow
  Downloading mlflow-1.27.0-py3-none-any.whl (17.9 MB)
     |████████████████████████████████| 17.9 MB 10.7 MB/s            
Collecting Flask
  Downloading Flask-2.1.2-py3-none-any.whl (95 kB)
     |████████████████████████████████| 95 kB 3.1 MB/s             
Requirement already satisfied: requests>=2.17.3 in /opt/conda/lib/python3.9/site-packages (from mlflow) (2.27.1)
Requirement already satisfied: protobuf>=3.12.0 in /opt/conda/lib/python3.9/site-packages (from mlflow) (3.18.1)
Requirement already satisfied: entrypoints in /opt/conda/lib/python3.9/site-packages (from mlflow) (0.3)
Collecting docker>=4.0.0
  Downloading docker-5.0.3-py2.py3-none-any.whl (146 kB)
     |████████████████████████████████| 146 kB 11.8 MB/s            
Collecting sqlparse>=0.3.1
  Downloading sqlparse-0.4.2-py3-none-any.whl (42 kB)
     |████████████████████████████████| 42 kB 798 kB/s             
Collecting querystring-parser
  Downloading querystring_parser-1.2.4-py2.py3-none-any.whl (7.9 kB)
Requirement already satisfied: sqlalchemy>=1.4.0 in /opt/conda/lib/python3.9/site-packages (from mlflow) (1.4.29)
Requirement already satisfied: packaging in /opt/conda/lib/python3.9/site-packages (from mlflow) (21.3)
Collecting gunicorn
  Downloading gunicorn-20.1.0-py3-none-any.whl (79 kB)
     |████████████████████████████████| 79 kB 5.0 MB/s             
Collecting databricks-cli>=0.8.7
  Downloading databricks-cli-0.17.0.tar.gz (81 kB)
     |████████████████████████████████| 81 kB 7.4 MB/s             
  Preparing metadata (setup.py) ... done
Requirement already satisfied: pytz in /opt/conda/lib/python3.9/site-packages (from mlflow) (2021.3)
Requirement already satisfied: cloudpickle in /opt/conda/lib/python3.9/site-packages (from mlflow) (2.0.0)
Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.9/site-packages (from mlflow) (6.0)
Requirement already satisfied: importlib-metadata!=4.7.0,>=3.7.0 in /opt/conda/lib/python3.9/site-packages (from mlflow) (4.10.0)
Collecting prometheus-flask-exporter
  Downloading prometheus_flask_exporter-0.20.2-py3-none-any.whl (18 kB)
Collecting gitpython>=2.1.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
     |████████████████████████████████| 181 kB 10.0 MB/s            
Requirement already satisfied: alembic in /opt/conda/lib/python3.9/site-packages (from mlflow) (1.7.5)
Requirement already satisfied: scipy in /opt/conda/lib/python3.9/site-packages (from mlflow) (1.7.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.9/site-packages (from mlflow) (1.19.5)
Requirement already satisfied: click>=7.0 in /opt/conda/lib/python3.9/site-packages (from mlflow) (8.0.3)
Requirement already satisfied: pandas in /opt/conda/lib/python3.9/site-packages (from mlflow) (1.3.5)
Requirement already satisfied: pyjwt>=1.7.0 in /opt/conda/lib/python3.9/site-packages (from databricks-cli>=0.8.7->mlflow) (2.3.0)
Requirement already satisfied: oauthlib>=3.1.0 in /opt/conda/lib/python3.9/site-packages (from databricks-cli>=0.8.7->mlflow) (3.1.1)
Collecting tabulate>=0.7.7
  Downloading tabulate-0.8.10-py3-none-any.whl (29 kB)
Requirement already satisfied: six>=1.10.0 in /opt/conda/lib/python3.9/site-packages (from databricks-cli>=0.8.7->mlflow) (1.15.0)
Requirement already satisfied: websocket-client>=0.32.0 in /opt/conda/lib/python3.9/site-packages (from docker>=4.0.0->mlflow) (1.2.3)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)
     |████████████████████████████████| 63 kB 1.9 MB/s             
Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.9/site-packages (from importlib-metadata!=4.7.0,>=3.7.0->mlflow) (3.7.0)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.9/site-packages (from requests>=2.17.3->mlflow) (2.0.10)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.9/site-packages (from requests>=2.17.3->mlflow) (3.1)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.9/site-packages (from requests>=2.17.3->mlflow) (2021.10.8)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.9/site-packages (from requests>=2.17.3->mlflow) (1.26.8)
Requirement already satisfied: greenlet!=0.4.17 in /opt/conda/lib/python3.9/site-packages (from sqlalchemy>=1.4.0->mlflow) (1.1.2)
Requirement already satisfied: Mako in /opt/conda/lib/python3.9/site-packages (from alembic->mlflow) (1.1.6)
Requirement already satisfied: Werkzeug>=2.0 in /opt/conda/lib/python3.9/site-packages (from Flask->mlflow) (2.0.1)
Collecting itsdangerous>=2.0
  Downloading itsdangerous-2.1.2-py3-none-any.whl (15 kB)
Requirement already satisfied: Jinja2>=3.0 in /opt/conda/lib/python3.9/site-packages (from Flask->mlflow) (3.0.3)
Requirement already satisfied: setuptools>=3.0 in /opt/conda/lib/python3.9/site-packages (from gunicorn->mlflow) (60.5.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.9/site-packages (from packaging->mlflow) (3.0.6)
Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.9/site-packages (from pandas->mlflow) (2.8.2)
Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.9/site-packages (from prometheus-flask-exporter->mlflow) (0.12.0)
Collecting smmap<6,>=3.0.1
  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.9/site-packages (from Jinja2>=3.0->Flask->mlflow) (2.0.1)
Building wheels for collected packages: databricks-cli
  Building wheel for databricks-cli (setup.py) ... done
  Created wheel for databricks-cli: filename=databricks_cli-0.17.0-py3-none-any.whl size=141960 sha256=af13856c6d0a21dde1b8a0f8573bd6969c313c1425c065d9b919ba8f270fd1de
  Stored in directory: /home/jovyan/.cache/pip/wheels/d5/b6/71/c3052c82e4a88dc658dd2616b944e130c1d0ff3f77e8f02df7
Successfully built databricks-cli
Installing collected packages: smmap, itsdangerous, tabulate, gitdb, Flask, sqlparse, querystring-parser, prometheus-flask-exporter, gunicorn, gitpython, docker, databricks-cli, mlflow
Successfully installed Flask-2.1.2 databricks-cli-0.17.0 docker-5.0.3 gitdb-4.0.9 gitpython-3.1.27 gunicorn-20.1.0 itsdangerous-2.1.2 mlflow-1.27.0 prometheus-flask-exporter-0.20.2 querystring-parser-1.2.4 smmap-5.0.0 sqlparse-0.4.2 tabulate-0.8.10
import os
import random

import torch
import numpy as np
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

# set the random seeds for reproducibility
random.seed(42)
torch.manual_seed(0)
np.random.seed(0)

Loading data

from torch.utils.data import random_split
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
from glob import glob

First download the data locally.

%%capture
!wget -O data.zip https://github.com/abin24/Magnetic-tile-defect-datasets./archive/master.zip
!unzip data.zip
!mv Magnetic-tile-defect-datasets.-master data
classes = ['Blowhole', 'Crack', 'Free']  # classes/labels
image_paths = []

for c in classes:
    # retrieve image file paths recursively
    images_found = glob('data/MT_' + c + '/Imgs/*.jpg', recursive=True)
    if c == 'Free':  # undersample the free class
        image_paths.extend(images_found[:80])
    else:
        image_paths.extend(images_found)
    
random.shuffle(image_paths)
len(image_paths)
252

Dataset

Write a helper class for data extraction, transformation, and preprocessing, following https://pytorch.org/docs/stable/data. Also see the binary segmentation intro in the segmentation models library for more details on designing the Dataset class.

import torchvision.transforms.functional as TF
import random

class SurfaceDefectDetectionDataset(Dataset):
    def __init__(self, image_path_list, use_transform=False):
        super().__init__()
        self.image_path_list = image_path_list
        self.use_transform = use_transform
        
    def transform(self, image, target):
        if random.random() < 0.5:
            image = TF.hflip(image)
            target = TF.hflip(target)
    
        if random.random() < 0.5:
            image = TF.vflip(image)
            target = TF.vflip(target)
    
        angle = random.choice([0, -90, 90, 180])
        image, target = TF.rotate(image, angle), TF.rotate(target, angle)
    
        return image, target
        
    def __len__(self):
        return len(self.image_path_list)
    
    def __getitem__(self, idx):
        # Open the image file which is in jpg  
        image = Image.open(self.image_path_list[idx])
        # The mask is in png. 
        # Use the image path, and change its extension to png to get the mask's path.
        mask = Image.open(os.path.splitext(self.image_path_list[idx])[0]+'.png') 
        
        # resize the images.
        image, mask = TF.resize(image, (320,320)), TF.resize(mask, (320,320))
        
        # Perform augmentation if required.
        if self.use_transform:
            image, mask = self.transform(image, mask)
        
        # Transform the image and mask PILs to torch tensors. 
        image, mask = TF.to_tensor(image), TF.to_tensor(mask)
        
        # Binarize the mask with a threshold of 0.5
        mask = (mask >= 0.5).float()
        
        #return the image and mask pair tensors
        return image, mask
split_len = int(0.8 * len(image_paths))
train_dataset = SurfaceDefectDetectionDataset(image_paths[:split_len], use_transform=True)
test_dataset = SurfaceDefectDetectionDataset(image_paths[split_len:], use_transform=False)
train_dataset, val_dataset = random_split(train_dataset, [int(split_len * 0.9), split_len - int(split_len * 0.9)], generator=torch.Generator().manual_seed(1))
print('Length of train dataset: ', len(train_dataset))
print('Length of validation dataset: ', len(val_dataset))
print('Length of test dataset: ', len(test_dataset))
Length of train dataset:  180
Length of validation dataset:  21
Length of test dataset:  51
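As a sanity check, the printed lengths follow directly from the split arithmetic above (assuming the 252 image paths found for this download; the exact count may vary):

```python
# Reproduce the 80/20 train+val/test split, then the 90/10 train/val split.
n_images = 252

split_len = int(0.8 * n_images)   # size of the train+val pool
n_train = int(split_len * 0.9)    # 90% of that pool
n_val = split_len - n_train       # remaining 10%
n_test = n_images - split_len     # held-out test set

print(n_train, n_val, n_test)     # 180 21 51
```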

Let's take a look at the dataset

import matplotlib.pyplot as plt
import random

sample_img, sample_msk = train_dataset[random.choice(range(len(train_dataset)))]
plt.subplot(1,2,1)
plt.title("Sample from training set")
plt.axis("off")
plt.imshow(sample_img.squeeze(), cmap='gray')
plt.subplot(1,2,2)
plt.axis("off")
plt.imshow(sample_msk.squeeze(), cmap='gray')
plt.show()

sample_img, sample_msk = val_dataset[random.choice(range(len(val_dataset)))]
plt.subplot(1,2,1)
plt.title("Sample from validation set")
plt.axis("off")
plt.imshow(sample_img.squeeze(), cmap='gray')
plt.subplot(1,2,2)
plt.axis("off")
plt.imshow(sample_msk.squeeze(), cmap='gray')
plt.show()

sample_img, sample_msk = test_dataset[random.choice(range(len(test_dataset)))]
plt.subplot(1,2,1)
plt.title("Sample from test set")
plt.axis("off")
plt.imshow(sample_img.squeeze(), cmap='gray')
plt.subplot(1,2,2)
plt.axis("off")
plt.imshow(sample_msk.squeeze(), cmap='gray')
plt.show()

Take a look at more samples from the train set.

for i in range(5):
    sample_img, sample_msk = train_dataset[random.choice(range(len(train_dataset)))]
    plt.subplot(1,2,1)
    plt.title("Image")
    plt.axis("off")
    plt.imshow(sample_img.squeeze(), cmap='gray')
    plt.subplot(1,2,2)
    plt.title("Mask")
    plt.axis("off")
    plt.imshow(sample_msk.squeeze(), cmap='gray')
    plt.show()

Find the weight of positive and negative pixels.

The number of positive pixels is less than 1% of the total, showing that the dataset is highly imbalanced.

positive_weight = 0
negative_weight = 0
total_pixels = 0
img_shape = train_dataset[0][0].shape
for _, target in train_dataset:
    positive_weight += (target >= 0.5).sum().item()
    negative_weight += (target < 0.5).sum().item()
    total_pixels += (img_shape[1] * img_shape[2])
positive_weight /= total_pixels
negative_weight /= total_pixels
print('positive weight = ',positive_weight, '\tnegative weight = ', negative_weight)
positive weight =  0.0022352430555555554 	negative weight =  0.9977647569444444
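These fractions suggest one optional mitigation (not used in the notebook as written): up-weighting positive pixels in the BCE loss. `torch.nn.BCEWithLogitsLoss` accepts a `pos_weight` argument for exactly this purpose; the scalar is just the ratio of the two fractions computed above.

```python
# Fractions of defect / background pixels, from the cell above.
positive_weight = 0.0022352430555555554
negative_weight = 1.0 - positive_weight

# Each positive pixel would count roughly 446x as much as a negative one.
pos_weight = negative_weight / positive_weight
print(round(pos_weight))
```

A tensor like `torch.tensor([pos_weight])` could then be passed to the loss; whether that helps more than the Dice/Tversky/focal alternatives tried below is an empirical question.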

Create model and train

from itertools import islice

def show_predictions_from_batch(model, dataloader, batch_num=0, limit=None):
    """
    Visualize model predictions for batch number `batch_num`.

    Show at most `limit` images (all of them if `limit` is None).
    """
    batch = next(islice(iter(dataloader), batch_num, None), None) # Selects the nth item from dataloader, returning None if not possible.
    images, masks = batch

    with torch.no_grad():
        model.eval()

        logits = model(images)

    pr_masks = logits.sigmoid()
    pr_masks = (pr_masks >= 0.5)*1

    for i, (image, gt_mask, pr_mask) in enumerate(zip(images, masks, pr_masks)):
        if limit and i == limit:
            break
        fig = plt.figure(figsize=(15,4))

        ax = fig.add_subplot(1,3,1)
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title("Image")
        ax.axis("off")

        ax = fig.add_subplot(1,3,2)
        ax.imshow(gt_mask.squeeze(), cmap='gray')
        ax.set_title("Ground truth")
        ax.axis("off")

        ax = fig.add_subplot(1,3,3)
        ax.imshow(pr_mask.squeeze(), cmap='gray')
        ax.set_title("Predicted mask")
        ax.axis("off")
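The batch selection above relies on a slightly non-obvious `islice` idiom to grab the nth item of an iterator. The same pattern works on any iterable; a minimal sketch with a toy list (the `nth` helper is just for illustration, not part of the notebook):

```python
from itertools import islice

def nth(iterable, n):
    """Return the nth item of `iterable`, or None if it is exhausted first."""
    return next(islice(iter(iterable), n, None), None)

batches = ['b0', 'b1', 'b2']
print(nth(batches, 1))   # b1
print(nth(batches, 5))   # None: index past the end
```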

We'll create a PyTorch Lightning module to help streamline the training process. In its initialization, it uses the segmentation models library, via the call to smp.create_model, to build a PyTorch model that operates on one-channel images for binary segmentation. Many state-of-the-art architectures are available through the segmentation models library; however, we'll typically use a U-Net with a ResNet-34 encoder.

We'll also use the segmentation models library to monitor the intersection-over-union (IoU) metric. On a dataset this imbalanced, IoU is a much better indicator of model quality than accuracy.
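To see why accuracy is misleading here, consider a toy 10-pixel "mask" (made-up values, purely illustrative): a model that predicts "no defect" everywhere still scores high accuracy, but its IoU is zero.

```python
gt   = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]   # 2 defect pixels out of 10
pred = [0] * 10                          # predict "no defect" everywhere

# Pixel accuracy: fraction of matching pixels
accuracy = sum(p == g for p, g in zip(pred, gt)) / len(gt)

# IoU: overlap of predicted and true defect pixels over their union
intersection = sum(p and g for p, g in zip(pred, gt))
union = sum(p or g for p, g in zip(pred, gt))
iou = intersection / union if union else 1.0

print(accuracy, iou)   # 0.8 0.0
```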

class SurfaceDefectModel(pl.LightningModule):

    def __init__(self, arch, encoder_name, loss = "SoftBCEWithLogitsLoss" , **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, encoder_weights = None, in_channels=1, classes=1, **kwargs
        )

        self.arch = arch
        self.encoder_name = encoder_name
        
        self.loss_name = loss
        if loss == "DiceLoss":
            self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        elif loss == "TverskyLoss":
            self.loss_fn = smp.losses.TverskyLoss(smp.losses.BINARY_MODE, from_logits=True, alpha=0.3,beta=0.7)
        elif loss == "FocalLoss":
            self.loss_fn = smp.losses.FocalLoss(smp.losses.BINARY_MODE)              
        else:
            self.loss_fn = smp.losses.SoftBCEWithLogitsLoss()
            
        self.printed_run_id = None
        self.run_id = None
        
    def forward(self, image):
        return self.model(image)

    def shared_step(self, batch, stage):
        
        image = batch[0]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32. The encoder and
        # decoder are connected by skip connections, and encoders usually have
        # 5 stages of downsampling by a factor of 2 (2^5 = 32); e.g. with a
        # 65x65 image the encoder produces odd-sized feature maps that no
        # longer match the upsampled decoder features, so concatenating them
        # raises an error.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch[1]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Let's compute metrics for some threshold:
        # first convert the logits to probabilities,
        # then apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute the IoU metric in two ways:
        #   1. dataset-wise
        #   2. image-wise
        # For now we just compute true positive, false positive, false negative,
        # and true negative 'pixels' for each image and class; these values
        # will be aggregated at the end of the epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])
        
        # Per-image IoU: first calculate the IoU score for each image,
        # then take the mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset 
        # with "empty" images (images without target class) a large gap could be observed. 
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        accuracy = smp.metrics.accuracy(tp, fp, fn, tn)
        
        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
            f"{stage}_accuracy": accuracy,
            f"{stage}_loss": torch.tensor([x["loss"].item() for x in outputs]).mean()
        }
        
        # Log all metrics (including the loss) to MLflow
        self.logger.log_metrics({key: val.mean().item() for key, val in metrics.items()}, step=self.current_epoch)
        
        # The loss goes only to MLflow; drop it before logging to the progress bar
        del metrics[f"{stage}_loss"]
        if not self.printed_run_id:
            print(self.logger.run_id )
            self.printed_run_id = True
            
        # self.log_dict sends the remaining metrics to the logger and the progress bar
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    def training_epoch_end(self, outputs):
        self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)
    
    def on_fit_end(self):
        # Log hyperparameters to mlflow.
        self.logger.experiment.log_param(self.logger.run_id ,"arch", self.arch)
        self.logger.experiment.log_param(self.logger.run_id ,"encoder_name", self.encoder_name)
        self.logger.experiment.log_param(self.logger.run_id ,"loss", self.loss_name)
        self.run_id = self.logger.run_id
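As a quick illustration of the thresholding step in `shared_step` above: the model emits raw logits, which a sigmoid maps to probabilities before a 0.5 cutoff produces the binary mask. A minimal sketch in plain Python (the logit values here are made up):

```python
import math

def sigmoid(x):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# hypothetical per-pixel logits from the segmentation head
logits = [-2.0, 0.1, 3.0]

# convert logits to probabilities, then threshold at 0.5, mirroring
# `logits_mask.sigmoid()` and `(prob_mask > 0.5).float()` in shared_step
probs = [sigmoid(z) for z in logits]
preds = [1.0 if p > 0.5 else 0.0 for p in probs]

print([round(p, 3) for p in probs])  # -> [0.119, 0.525, 0.953]
print(preds)                         # -> [0.0, 1.0, 1.0]
```

Note that the loss is still computed on the raw logits (`from_logits=True`); the sigmoid and threshold are only applied for metric computation.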
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
model = SurfaceDefectModel("Unet", "resnet34")
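Before training, it helps to see why the two IoU reductions computed in `shared_epoch_end` can disagree on imbalanced data. A hand-worked sketch with made-up pixel counts (plain Python, no smp required):

```python
# Per-image (tp, fp, fn) pixel counts for two hypothetical images.
# Image 2 has an empty ground-truth mask but one false-positive pixel.
stats = [
    (5, 5, 5),   # image 1: IoU = 5 / (5 + 5 + 5) = 1/3
    (0, 1, 0),   # image 2: IoU = 0 / (0 + 1 + 0) = 0
]

def iou(tp, fp, fn):
    denom = tp + fp + fn
    return tp / denom if denom else 1.0  # treat empty-vs-empty as a perfect match

# "micro-imagewise": score each image separately, then average
per_image_iou = sum(iou(*s) for s in stats) / len(stats)

# "micro": pool the counts over the whole dataset, then score once
tp, fp, fn = (sum(col) for col in zip(*stats))
dataset_iou = iou(tp, fp, fn)

print(round(per_image_iou, 4))  # -> 0.1667 (dragged down by the empty image)
print(round(dataset_iou, 4))    # -> 0.3125 (5 / 16)
```

The empty image contributes a hard zero to the per-image average but only a single pixel to the pooled counts, which is why `per_image_iou` is far more sensitive to empty images than `dataset_iou`.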

Sanity-check the untrained model by showing its predictions on a training batch.

show_predictions_from_batch(model, train_loader, batch_num=0, limit=1)

We'll use the ModelCheckpoint callback from PyTorch Lightning to save the best model, as measured by the validation intersection-over-union metric.

from pytorch_lightning.callbacks import ModelCheckpoint
from pathlib import Path

checkpoint_callback = ModelCheckpoint(
    monitor="valid_dataset_iou",
    dirpath="./models",
    filename=f"surface_defect_{model.arch}_{model.encoder_name}_{model.loss_name}",
    save_top_k=1,  # keep only the best checkpoint so the filename stays stable
    mode="max",
)

# Create the models directory if it doesn't exist
Path("./models").mkdir(exist_ok=True)

Now, with the help of PyTorch Lightning, we can train and log to MLflow in a few lines of code.

mlf_logger = MLFlowLogger(experiment_name="lightning_logs")
trainer = pl.Trainer(
    gpus=1, 
    max_epochs=200,
    callbacks=[checkpoint_callback],
    logger=mlf_logger,
)

trainer.fit(
    model, 
    train_dataloaders=train_loader, 
    val_dataloaders=valid_loader,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type                  | Params
--------------------------------------------------
0 | model   | Unet                  | 24.4 M
1 | loss_fn | SoftBCEWithLogitsLoss | 0     
--------------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.720    Total estimated model params size (MB)
Experiment with name lightning_logs not found. Creating it.
0947cdf639374feab6d819b7ab2cfb0c
/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (23) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(

Load the best model to analyze its performance.

# load_from_checkpoint is a classmethod that returns a new model instance,
# so assign the result rather than calling it on the existing model
model = SurfaceDefectModel.load_from_checkpoint(
    f"models/surface_defect_{model.arch}_{model.encoder_name}_{model.loss_name}.ckpt",
    arch=model.arch,
    encoder_name=model.encoder_name,
    loss=model.loss_fn.__class__.__name__,
)
SurfaceDefectModel(
  (model): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (5): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (decoder): UnetDecoder(
      (center): Identity()
      (blocks): ModuleList(
        (0): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (1): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (2): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (3): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (4): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
      )
    )
    (segmentation_head): SegmentationHead(
      (0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Identity()
      (2): Activation(
        (activation): Identity()
      )
    )
  )
  (loss_fn): SoftBCEWithLogitsLoss()
)
trainer.validate(model, dataloaders=valid_loader, verbose=False)
[{'valid_per_image_iou': 0.0018078746506944299,
  'valid_dataset_iou': 0.003512033959850669,
  'valid_accuracy': 0.995513916015625}]

Visualize the model performance on the validation set. Note the accuracy above 99% despite a near-zero IoU: with under 1% of pixels belonging to the defect class, a model earns high accuracy just by predicting background almost everywhere, which is why we track IoU.

for i in range(len(valid_loader)):
    show_predictions_from_batch(model, valid_loader, batch_num=i)
/tmp/ipykernel_23/3728964761.py:23: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  fig = plt.figure(figsize=(15,4))

Analyze the best saved model on the test set.

test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)
trainer.test(model, dataloaders=test_loader, verbose=False)
[{'test_per_image_iou': 0.08222604542970657,
  'test_dataset_iou': 0.010361794382333755,
  'test_accuracy': 0.996726393699646}]
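The near-perfect accuracy sitting next to a single-digit IoU is exactly the class-imbalance trap this dataset poses. A toy sketch with made-up masks shows how a degenerate model can score ~99% accuracy while finding nothing:

```python
# A hypothetical 10x10 image flattened to 100 pixels:
# 99 background pixels and 1 defect pixel.
target = [0] * 99 + [1]
pred = [0] * 100  # degenerate model: predicts "no defect" everywhere

correct = sum(p == t for p, t in zip(pred, target))
accuracy = correct / len(target)

tp = sum(p and t for p, t in zip(pred, target))
fp = sum(p and not t for p, t in zip(pred, target))
fn = sum(t and not p for p, t in zip(pred, target))
iou = tp / (tp + fp + fn) if (tp + fp + fn) else 1.0

print(accuracy)  # -> 0.99, despite finding no defects
print(iou)       # -> 0.0, IoU exposes the failure
```

This is also why loss functions such as focal and Tversky loss, which weight the rare positive class more heavily, are worth trying on this dataset.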

Finally, visualize the model performance on the test set.

for i in range(len(test_loader)):
    show_predictions_from_batch(model, test_loader, batch_num=i)
/tmp/ipykernel_23/3728964761.py:23: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  fig = plt.figure(figsize=(15,4))